{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# A nearest neighbor learning algorithm example using the TensorFlow library.\n",
    "# This example uses the MNIST database of handwritten digits\n",
    "# (http://yann.lecun.com/exdb/mnist/)\n",
    "\n",
    "# Author: Aymeric Damien\n",
    "# Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# Import MNIST data\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "# Read the dataset files under MNIST_data/; one_hot=True requests one-hot label\n",
    "# vectors (presumably one slot per digit class -- confirm against input_data).\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Work on a small slice of MNIST so the brute-force search stays fast.\n",
    "N_TRAIN = 5000     # training images used as nearest-neighbor candidates\n",
    "N_TEST = 200       # test images to classify\n",
    "N_FEATURES = 784   # length of each image vector fed to the graph\n",
    "\n",
    "Xtr, Ytr = mnist.train.next_batch(N_TRAIN)\n",
    "Xte, Yte = mnist.test.next_batch(N_TEST)\n",
    "\n",
    "# tf Graph Input: the whole candidate set and a single query image\n",
    "xtr = tf.placeholder(\"float\", [None, N_FEATURES])\n",
    "xte = tf.placeholder(\"float\", [N_FEATURES])\n",
    "\n",
    "# Nearest Neighbor calculation using L1 Distance.\n",
    "# xte broadcasts across the rows of xtr; summing |difference| over axis 1\n",
    "# leaves one distance per candidate row.\n",
    "# `axis` replaces the deprecated `reduction_indices` keyword.\n",
    "distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), axis=1)\n",
    "# Prediction: index of the smallest distance (the nearest neighbor).\n",
    "# tf.argmin replaces the deprecated tf.arg_min alias.\n",
    "pred = tf.argmin(distance, 0)\n",
    "\n",
    "# Accuracy value, updated by the evaluation loop in the next cell.\n",
    "accuracy = 0.\n",
    "\n",
    "# Op that initializes graph variables; the session runs it before the loop.\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test 0 Prediction: 7 True Class: 7\n",
      "Test 1 Prediction: 2 True Class: 2\n",
      "Test 2 Prediction: 1 True Class: 1\n",
      "Test 3 Prediction: 0 True Class: 0\n",
      "Test 4 Prediction: 4 True Class: 4\n",
      "Test 5 Prediction: 1 True Class: 1\n",
      "Test 6 Prediction: 4 True Class: 4\n",
      "Test 7 Prediction: 9 True Class: 9\n",
      "Test 8 Prediction: 8 True Class: 5\n",
      "Test 9 Prediction: 9 True Class: 9\n",
      "Test 10 Prediction: 0 True Class: 0\n",
      "Test 11 Prediction: 0 True Class: 6\n",
      "Test 12 Prediction: 9 True Class: 9\n",
      "Test 13 Prediction: 0 True Class: 0\n",
      "Test 14 Prediction: 1 True Class: 1\n",
      "Test 15 Prediction: 5 True Class: 5\n",
      "Test 16 Prediction: 4 True Class: 9\n",
      "Test 17 Prediction: 7 True Class: 7\n",
      "Test 18 Prediction: 3 True Class: 3\n",
      "Test 19 Prediction: 4 True Class: 4\n",
      "Test 20 Prediction: 9 True Class: 9\n",
      "Test 21 Prediction: 6 True Class: 6\n",
      "Test 22 Prediction: 6 True Class: 6\n",
      "Test 23 Prediction: 5 True Class: 5\n",
      "Test 24 Prediction: 4 True Class: 4\n",
      "Test 25 Prediction: 0 True Class: 0\n",
      "Test 26 Prediction: 7 True Class: 7\n",
      "Test 27 Prediction: 4 True Class: 4\n",
      "Test 28 Prediction: 0 True Class: 0\n",
      "Test 29 Prediction: 1 True Class: 1\n",
      "Test 30 Prediction: 3 True Class: 3\n",
      "Test 31 Prediction: 1 True Class: 1\n",
      "Test 32 Prediction: 3 True Class: 3\n",
      "Test 33 Prediction: 4 True Class: 4\n",
      "Test 34 Prediction: 7 True Class: 7\n",
      "Test 35 Prediction: 2 True Class: 2\n",
      "Test 36 Prediction: 7 True Class: 7\n",
      "Test 37 Prediction: 1 True Class: 1\n",
      "Test 38 Prediction: 2 True Class: 2\n",
      "Test 39 Prediction: 1 True Class: 1\n",
      "Test 40 Prediction: 1 True Class: 1\n",
      "Test 41 Prediction: 7 True Class: 7\n",
      "Test 42 Prediction: 4 True Class: 4\n",
      "Test 43 Prediction: 1 True Class: 2\n",
      "Test 44 Prediction: 3 True Class: 3\n",
      "Test 45 Prediction: 5 True Class: 5\n",
      "Test 46 Prediction: 1 True Class: 1\n",
      "Test 47 Prediction: 2 True Class: 2\n",
      "Test 48 Prediction: 4 True Class: 4\n",
      "Test 49 Prediction: 4 True Class: 4\n",
      "Test 50 Prediction: 6 True Class: 6\n",
      "Test 51 Prediction: 3 True Class: 3\n",
      "Test 52 Prediction: 5 True Class: 5\n",
      "Test 53 Prediction: 5 True Class: 5\n",
      "Test 54 Prediction: 6 True Class: 6\n",
      "Test 55 Prediction: 0 True Class: 0\n",
      "Test 56 Prediction: 4 True Class: 4\n",
      "Test 57 Prediction: 1 True Class: 1\n",
      "Test 58 Prediction: 9 True Class: 9\n",
      "Test 59 Prediction: 5 True Class: 5\n",
      "Test 60 Prediction: 7 True Class: 7\n",
      "Test 61 Prediction: 8 True Class: 8\n",
      "Test 62 Prediction: 9 True Class: 9\n",
      "Test 63 Prediction: 3 True Class: 3\n",
      "Test 64 Prediction: 7 True Class: 7\n",
      "Test 65 Prediction: 4 True Class: 4\n",
      "Test 66 Prediction: 6 True Class: 6\n",
      "Test 67 Prediction: 4 True Class: 4\n",
      "Test 68 Prediction: 3 True Class: 3\n",
      "Test 69 Prediction: 0 True Class: 0\n",
      "Test 70 Prediction: 7 True Class: 7\n",
      "Test 71 Prediction: 0 True Class: 0\n",
      "Test 72 Prediction: 2 True Class: 2\n",
      "Test 73 Prediction: 7 True Class: 9\n",
      "Test 74 Prediction: 1 True Class: 1\n",
      "Test 75 Prediction: 7 True Class: 7\n",
      "Test 76 Prediction: 3 True Class: 3\n",
      "Test 77 Prediction: 7 True Class: 2\n",
      "Test 78 Prediction: 9 True Class: 9\n",
      "Test 79 Prediction: 7 True Class: 7\n",
      "Test 80 Prediction: 7 True Class: 7\n",
      "Test 81 Prediction: 6 True Class: 6\n",
      "Test 82 Prediction: 2 True Class: 2\n",
      "Test 83 Prediction: 7 True Class: 7\n",
      "Test 84 Prediction: 8 True Class: 8\n",
      "Test 85 Prediction: 4 True Class: 4\n",
      "Test 86 Prediction: 7 True Class: 7\n",
      "Test 87 Prediction: 3 True Class: 3\n",
      "Test 88 Prediction: 6 True Class: 6\n",
      "Test 89 Prediction: 1 True Class: 1\n",
      "Test 90 Prediction: 3 True Class: 3\n",
      "Test 91 Prediction: 6 True Class: 6\n",
      "Test 92 Prediction: 9 True Class: 9\n",
      "Test 93 Prediction: 3 True Class: 3\n",
      "Test 94 Prediction: 1 True Class: 1\n",
      "Test 95 Prediction: 4 True Class: 4\n",
      "Test 96 Prediction: 1 True Class: 1\n",
      "Test 97 Prediction: 7 True Class: 7\n",
      "Test 98 Prediction: 6 True Class: 6\n",
      "Test 99 Prediction: 9 True Class: 9\n",
      "Test 100 Prediction: 6 True Class: 6\n",
      "Test 101 Prediction: 0 True Class: 0\n",
      "Test 102 Prediction: 5 True Class: 5\n",
      "Test 103 Prediction: 4 True Class: 4\n",
      "Test 104 Prediction: 9 True Class: 9\n",
      "Test 105 Prediction: 9 True Class: 9\n",
      "Test 106 Prediction: 2 True Class: 2\n",
      "Test 107 Prediction: 1 True Class: 1\n",
      "Test 108 Prediction: 9 True Class: 9\n",
      "Test 109 Prediction: 4 True Class: 4\n",
      "Test 110 Prediction: 8 True Class: 8\n",
      "Test 111 Prediction: 7 True Class: 7\n",
      "Test 112 Prediction: 3 True Class: 3\n",
      "Test 113 Prediction: 9 True Class: 9\n",
      "Test 114 Prediction: 7 True Class: 7\n",
      "Test 115 Prediction: 9 True Class: 4\n",
      "Test 116 Prediction: 9 True Class: 4\n",
      "Test 117 Prediction: 4 True Class: 4\n",
      "Test 118 Prediction: 9 True Class: 9\n",
      "Test 119 Prediction: 7 True Class: 2\n",
      "Test 120 Prediction: 5 True Class: 5\n",
      "Test 121 Prediction: 4 True Class: 4\n",
      "Test 122 Prediction: 7 True Class: 7\n",
      "Test 123 Prediction: 6 True Class: 6\n",
      "Test 124 Prediction: 7 True Class: 7\n",
      "Test 125 Prediction: 9 True Class: 9\n",
      "Test 126 Prediction: 0 True Class: 0\n",
      "Test 127 Prediction: 5 True Class: 5\n",
      "Test 128 Prediction: 8 True Class: 8\n",
      "Test 129 Prediction: 5 True Class: 5\n",
      "Test 130 Prediction: 6 True Class: 6\n",
      "Test 131 Prediction: 6 True Class: 6\n",
      "Test 132 Prediction: 5 True Class: 5\n",
      "Test 133 Prediction: 7 True Class: 7\n",
      "Test 134 Prediction: 8 True Class: 8\n",
      "Test 135 Prediction: 1 True Class: 1\n",
      "Test 136 Prediction: 0 True Class: 0\n",
      "Test 137 Prediction: 1 True Class: 1\n",
      "Test 138 Prediction: 6 True Class: 6\n",
      "Test 139 Prediction: 4 True Class: 4\n",
      "Test 140 Prediction: 6 True Class: 6\n",
      "Test 141 Prediction: 7 True Class: 7\n",
      "Test 142 Prediction: 2 True Class: 3\n",
      "Test 143 Prediction: 1 True Class: 1\n",
      "Test 144 Prediction: 7 True Class: 7\n",
      "Test 145 Prediction: 1 True Class: 1\n",
      "Test 146 Prediction: 8 True Class: 8\n",
      "Test 147 Prediction: 2 True Class: 2\n",
      "Test 148 Prediction: 0 True Class: 0\n",
      "Test 149 Prediction: 1 True Class: 2\n",
      "Test 150 Prediction: 9 True Class: 9\n",
      "Test 151 Prediction: 9 True Class: 9\n",
      "Test 152 Prediction: 5 True Class: 5\n",
      "Test 153 Prediction: 5 True Class: 5\n",
      "Test 154 Prediction: 1 True Class: 1\n",
      "Test 155 Prediction: 5 True Class: 5\n",
      "Test 156 Prediction: 6 True Class: 6\n",
      "Test 157 Prediction: 0 True Class: 0\n",
      "Test 158 Prediction: 3 True Class: 3\n",
      "Test 159 Prediction: 4 True Class: 4\n",
      "Test 160 Prediction: 4 True Class: 4\n",
      "Test 161 Prediction: 6 True Class: 6\n",
      "Test 162 Prediction: 5 True Class: 5\n",
      "Test 163 Prediction: 4 True Class: 4\n",
      "Test 164 Prediction: 6 True Class: 6\n",
      "Test 165 Prediction: 5 True Class: 5\n",
      "Test 166 Prediction: 4 True Class: 4\n",
      "Test 167 Prediction: 5 True Class: 5\n",
      "Test 168 Prediction: 1 True Class: 1\n",
      "Test 169 Prediction: 4 True Class: 4\n",
      "Test 170 Prediction: 9 True Class: 4\n",
      "Test 171 Prediction: 7 True Class: 7\n",
      "Test 172 Prediction: 2 True Class: 2\n",
      "Test 173 Prediction: 3 True Class: 3\n",
      "Test 174 Prediction: 2 True Class: 2\n",
      "Test 175 Prediction: 1 True Class: 7\n",
      "Test 176 Prediction: 1 True Class: 1\n",
      "Test 177 Prediction: 8 True Class: 8\n",
      "Test 178 Prediction: 1 True Class: 1\n",
      "Test 179 Prediction: 8 True Class: 8\n",
      "Test 180 Prediction: 1 True Class: 1\n",
      "Test 181 Prediction: 8 True Class: 8\n",
      "Test 182 Prediction: 5 True Class: 5\n",
      "Test 183 Prediction: 0 True Class: 0\n",
      "Test 184 Prediction: 2 True Class: 8\n",
      "Test 185 Prediction: 9 True Class: 9\n",
      "Test 186 Prediction: 2 True Class: 2\n",
      "Test 187 Prediction: 5 True Class: 5\n",
      "Test 188 Prediction: 0 True Class: 0\n",
      "Test 189 Prediction: 1 True Class: 1\n",
      "Test 190 Prediction: 1 True Class: 1\n",
      "Test 191 Prediction: 1 True Class: 1\n",
      "Test 192 Prediction: 0 True Class: 0\n",
      "Test 193 Prediction: 4 True Class: 9\n",
      "Test 194 Prediction: 0 True Class: 0\n",
      "Test 195 Prediction: 1 True Class: 3\n",
      "Test 196 Prediction: 1 True Class: 1\n",
      "Test 197 Prediction: 6 True Class: 6\n",
      "Test 198 Prediction: 4 True Class: 4\n",
      "Test 199 Prediction: 2 True Class: 2\n",
      "Done!\n",
      "Accuracy: 0.92\n"
     ]
    }
   ],
   "source": [
    "# Launch the graph\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "\n",
    "    # Count hits locally and start from zero on every run, so re-executing\n",
    "    # this cell cannot stack a new score on top of a previous one.\n",
    "    n_test = len(Xte)\n",
    "    n_correct = 0\n",
    "\n",
    "    # loop over test data\n",
    "    for i in range(n_test):\n",
    "        # Get nearest neighbor\n",
    "        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})\n",
    "        # Labels were loaded with one_hot=True, so argmax recovers the class\n",
    "        predicted = np.argmax(Ytr[nn_index])\n",
    "        actual = np.argmax(Yte[i])\n",
    "        # Single-argument print(...) emits the same text on Python 2 and 3\n",
    "        print(\"Test %d Prediction: %d True Class: %d\" % (i, predicted, actual))\n",
    "        if predicted == actual:\n",
    "            n_correct += 1\n",
    "\n",
    "    # One division at the end avoids float drift from repeated additions;\n",
    "    # an empty test set scores 0 instead of dividing by zero.\n",
    "    accuracy = float(n_correct) / n_test if n_test else 0.\n",
    "    print(\"Done!\")\n",
    "    print(\"Accuracy: %s\" % accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
